{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Neural Cognitive Diagnosis Model (NCDM)\n",
    "This notebook will show you how to train and use the NCDM. First, we will show how to get the data (here we use a0910 as the dataset). Then we will show how to train a NCDM and perform the parameters persistence. At last, we will show how to load the parameters from the file and evaluate on the test dataset.\n",
    "\n",
    "The script version could be found in [NCDM.py](NCDM.py)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data Preparation\n",
    "Before we process the data, we need to first acquire the dataset which is shown in this [prepare_dataset.ipynb](prepare_dataset.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>user_id</th>\n",
       "      <th>item_id</th>\n",
       "      <th>score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1615</td>\n",
       "      <td>12977</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>782</td>\n",
       "      <td>13124</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1084</td>\n",
       "      <td>16475</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>593</td>\n",
       "      <td>8690</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>127</td>\n",
       "      <td>14225</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   user_id  item_id  score\n",
       "0     1615    12977      1\n",
       "1      782    13124      0\n",
       "2     1084    16475      0\n",
       "3      593     8690      0\n",
       "4      127    14225      1"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the data from files\n",
    "import pandas as pd\n",
    "\n",
    "train_data = pd.read_csv(\"../../data/a0910/train.csv\")\n",
    "valid_data = pd.read_csv(\"../../data/a0910/valid.csv\")\n",
    "test_data = pd.read_csv(\"../../data/a0910/test.csv\")\n",
    "df_item = pd.read_csv(\"../../data/a0910/item.csv\")\n",
    "item2knowledge = {}\n",
    "knowledge_set = set()\n",
    "for i, s in df_item.iterrows():\n",
    "    item_id, knowledge_codes = s['item_id'], list(set(eval(s['knowledge_code'])))\n",
    "    item2knowledge[item_id] = knowledge_codes\n",
    "    knowledge_set.update(knowledge_codes)\n",
    "\n",
    "train_data.head(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(186049, 25606, 55760)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_data), len(valid_data), len(test_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4128, 17746, 123)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Get basic data info for model initialization\n",
    "import numpy as np\n",
    "user_n = np.max(train_data['user_id'])\n",
    "item_n = np.max([np.max(train_data['item_id']), np.max(valid_data['item_id']), np.max(test_data['item_id'])])\n",
    "knowledge_n = np.max(list(knowledge_set))\n",
    "\n",
    "user_n, item_n, knowledge_n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(<torch.utils.data.dataloader.DataLoader at 0x2b8f1ecee20>,\n",
       " <torch.utils.data.dataloader.DataLoader at 0x2b8f1eba1c0>,\n",
       " <torch.utils.data.dataloader.DataLoader at 0x2b8f1eba3a0>)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Transform data to torch Dataloader (i.e., batchify)\n",
    "# batch_size is set to 32\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "\n",
    "batch_size = 32\n",
    "def transform(user, item, item2knowledge, score, batch_size):\n",
    "    knowledge_emb = torch.zeros((len(item), knowledge_n))\n",
    "    for idx in range(len(item)):\n",
    "        knowledge_emb[idx][np.array(item2knowledge[item[idx]]) - 1] = 1.0\n",
    "\n",
    "    data_set = TensorDataset(\n",
    "        torch.tensor(user, dtype=torch.int64) - 1,  # (1, user_n) to (0, user_n-1)\n",
    "        torch.tensor(item, dtype=torch.int64) - 1,  # (1, item_n) to (0, item_n-1)\n",
    "        knowledge_emb,\n",
    "        torch.tensor(score, dtype=torch.float32)\n",
    "    )\n",
    "    return DataLoader(data_set, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "\n",
    "train_set, valid_set, test_set = [\n",
    "    transform(data[\"user_id\"], data[\"item_id\"], item2knowledge, data[\"score\"], batch_size)\n",
    "    for data in [train_data, valid_data, test_data]\n",
    "]\n",
    "\n",
    "train_set, valid_set, test_set"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training and Persistence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "logging.getLogger().setLevel(logging.INFO)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 0: 100%|█████████████████████████████████████████████████████████████████████| 5815/5815 [01:14<00:00, 77.69it/s]\n",
      "Evaluating:   5%|███▍                                                                | 41/801 [00:00<00:02, 356.55it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Epoch 0] average loss: 0.650993\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluating: 100%|███████████████████████████████████████████████████████████████████| 801/801 [00:02<00:00, 380.23it/s]\n",
      "Epoch 1:   0%|                                                                       | 10/5815 [00:00<01:06, 86.67it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Epoch 0] auc: 0.712843, accuracy: 0.661056\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 1: 100%|█████████████████████████████████████████████████████████████████████| 5815/5815 [01:10<00:00, 82.66it/s]\n",
      "Evaluating:   5%|███▍                                                                | 40/801 [00:00<00:01, 397.00it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Epoch 1] average loss: 0.528788\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluating: 100%|███████████████████████████████████████████████████████████████████| 801/801 [00:02<00:00, 399.81it/s]\n",
      "Epoch 2:   0%|                                                                        | 9/5815 [00:00<01:12, 80.06it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Epoch 1] auc: 0.741101, accuracy: 0.722018\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 2: 100%|█████████████████████████████████████████████████████████████████████| 5815/5815 [01:10<00:00, 83.00it/s]\n",
      "Evaluating:   5%|███▍                                                                | 41/801 [00:00<00:01, 394.46it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Epoch 2] average loss: 0.466515\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluating: 100%|███████████████████████████████████████████████████████████████████| 801/801 [00:02<00:00, 390.05it/s]\n",
      "INFO:root:save parameters to ncdm.snapshot\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Epoch 2] auc: 0.749467, accuracy: 0.728032\n"
     ]
    }
   ],
   "source": [
    "from EduCDM import NCDM\n",
    "\n",
    "cdm = NCDM(knowledge_n, item_n, user_n)\n",
    "cdm.train(train_set, valid_set, epoch=3, device=\"cuda\")\n",
    "cdm.save(\"ncdm.snapshot\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Loading and Testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:load parameters from ncdm.snapshot\n",
      "Evaluating: 100%|█████████████████████████████████████████████████████████████████| 1743/1743 [00:02<00:00, 763.27it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "auc: 0.751848, accuracy: 0.730273\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "cdm.load(\"ncdm.snapshot\")\n",
    "auc, accuracy = cdm.eval(test_set)\n",
    "print(\"auc: %.6f, accuracy: %.6f\" % (auc, accuracy))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
